{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:30:41.922981Z",
     "iopub.status.busy": "2025-01-25T07:30:41.922736Z",
     "iopub.status.idle": "2025-01-25T07:30:43.819183Z",
     "shell.execute_reply": "2025-01-25T07:30:43.818575Z",
     "shell.execute_reply.started": "2025-01-25T07:30:41.922961Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=10, micro=14, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.0\n",
      "torch 2.5.1+cu124\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "    \n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n",
    "\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:30:43.820274Z",
     "iopub.status.busy": "2025-01-25T07:30:43.819983Z",
     "iopub.status.idle": "2025-01-25T07:31:05.048652Z",
     "shell.execute_reply": "2025-01-25T07:31:05.048103Z",
     "shell.execute_reply.started": "2025-01-25T07:30:43.820257Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2025-01-25 15:30:43--  https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n",
      "正在解析主机 storage.googleapis.com (storage.googleapis.com)... 172.217.14.251, 142.251.215.251, 142.250.217.123, ...\n",
      "正在连接 storage.googleapis.com (storage.googleapis.com)|172.217.14.251|:443... 已连接。\n",
      "已发出 HTTP 请求，正在等待回应... 200 OK\n",
      "长度： 1115394 (1.1M) [text/plain]\n",
      "正在保存至: ‘shakespeare.txt.1’\n",
      "\n",
      "shakespeare.txt.1   100%[===================>]   1.06M   500KB/s    用时 2.2s    \n",
      "\n",
      "2025-01-25 15:31:04 (500 KB/s) - 已保存 ‘shakespeare.txt.1’ [1115394/1115394])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.049424Z",
     "iopub.status.busy": "2025-01-25T07:31:05.049245Z",
     "iopub.status.idle": "2025-01-25T07:31:05.055313Z",
     "shell.execute_reply": "2025-01-25T07:31:05.054916Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.049405Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "length 1115394\n",
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You\n"
     ]
    }
   ],
   "source": [
    "# https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n",
    "#文件已经下载好了\n",
    "with open(\"./shakespeare.txt\", \"r\", encoding=\"utf8\") as file:\n",
    "    text = file.read()\n",
    "\n",
    "print(\"length\", len(text))\n",
    "print(text[0:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构造字典"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.056126Z",
     "iopub.status.busy": "2025-01-25T07:31:05.055778Z",
     "iopub.status.idle": "2025-01-25T07:31:05.071714Z",
     "shell.execute_reply": "2025-01-25T07:31:05.071225Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.056110Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65\n",
      "['\\n', ' ', '!', '$', '&', \"'\", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
     ]
    }
   ],
   "source": [
    "# 1. generate vocab\n",
    "# 2. build mapping char->id\n",
    "# 3. data -> id_data  把数据都转为id\n",
    "# 4. a b c d [EOS] -> [BOS] b c d  预测下一个字符生成的模型，也就是输入是a，输出就是b\n",
    "\n",
    "#去重，留下独立字符，并排序\n",
    "vocab = sorted(set(text))\n",
    "print(len(vocab))\n",
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.073198Z",
     "iopub.status.busy": "2025-01-25T07:31:05.072880Z",
     "iopub.status.idle": "2025-01-25T07:31:05.075488Z",
     "shell.execute_reply": "2025-01-25T07:31:05.075121Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.073183Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 how\n",
      "1 are\n",
      "2 you\n"
     ]
    }
   ],
   "source": [
    "for idx,char in enumerate(['how','are','you']):\n",
    "    print(idx,char)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.076171Z",
     "iopub.status.busy": "2025-01-25T07:31:05.075916Z",
     "iopub.status.idle": "2025-01-25T07:31:05.078593Z",
     "shell.execute_reply": "2025-01-25T07:31:05.078216Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.076156Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n"
     ]
    }
   ],
   "source": [
    "#每个字符都编好号，enumerate对每一个位置编号，生成的是列表中是元组，下面字典生成式\n",
    "char2idx = {char:idx for idx, char in enumerate(vocab)}\n",
    "print(char2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.079340Z",
     "iopub.status.busy": "2025-01-25T07:31:05.079008Z",
     "iopub.status.idle": "2025-01-25T07:31:05.081923Z",
     "shell.execute_reply": "2025-01-25T07:31:05.081476Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.079325Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n",
      " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n",
      " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n",
      " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n"
     ]
    }
   ],
   "source": [
    "# 把vocab从列表变为ndarray\n",
    "idx2char = np.array(vocab)\n",
    "print(idx2char)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.082495Z",
     "iopub.status.busy": "2025-01-25T07:31:05.082362Z",
     "iopub.status.idle": "2025-01-25T07:31:05.164961Z",
     "shell.execute_reply": "2025-01-25T07:31:05.164579Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.082482Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1115394,)\n",
      "1115394\n",
      "[18 47 56 57 58  1 15 47 58 47]\n",
      "First Citi\n"
     ]
    }
   ],
   "source": [
    "#把字符都转换为id\n",
    "text_as_int = np.array([char2idx[c] for c in text])\n",
    "print(text_as_int.shape)\n",
    "print(len(text_as_int))\n",
    "print(text_as_int[0:10])\n",
    "print(text[0:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 把莎士比亚文集分成一个一个的样本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.165581Z",
     "iopub.status.busy": "2025-01-25T07:31:05.165422Z",
     "iopub.status.idle": "2025-01-25T07:31:05.170725Z",
     "shell.execute_reply": "2025-01-25T07:31:05.170234Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.165562Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class CharDataset(Dataset):\n",
    "    def __init__(self, text_as_int, seq_length):\n",
    "        self.sub_len = seq_length + 1\n",
    "        self.text_as_int = text_as_int\n",
    "        self.num_seq = len(text_as_int) // self.sub_len\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        return self.text_as_int[index * self.sub_len: (index + 1) * self.sub_len]\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.num_seq\n",
    "    \n",
    "def collat_fct(batch):\n",
    "    src_list = []\n",
    "    trg_list = []\n",
    "    for part in batch:\n",
    "        src_list.append(part[:-1])\n",
    "        trg_list.append(part[1:])\n",
    "        \n",
    "    src_list = np.array(src_list)\n",
    "    trg_list = np.array(trg_list)\n",
    "    return torch.Tensor(src_list).to(dtype=torch.int64), torch.Tensor(trg_list).to(dtype=torch.int64)\n",
    "        \n",
    "\n",
    "train_ds = CharDataset(text_as_int, 100)\n",
    "train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collat_fct)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.171507Z",
     "iopub.status.busy": "2025-01-25T07:31:05.171186Z",
     "iopub.status.idle": "2025-01-25T07:31:05.309520Z",
     "shell.execute_reply": "2025-01-25T07:31:05.309039Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.171491Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================== 一层单向 LSTM ===================================\n",
      "            embedding.weight            paramerters num: 16640\n",
      "           lstm.weight_ih_l0            paramerters num: 1048576\n",
      "           lstm.weight_hh_l0            paramerters num: 4194304\n",
      "            lstm.bias_ih_l0             paramerters num: 4096\n",
      "            lstm.bias_hh_l0             paramerters num: 4096\n",
      "               fc.weight                paramerters num: 66560\n",
      "                fc.bias                 paramerters num: 65\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 100, 65])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CharLSTM(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024):\n",
    "        super(CharLSTM, self).__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_dim, vocab_size)\n",
    "        \n",
    "    def forward(self, x, hidden=None):\n",
    "        x = self.embedding(x)\n",
    "        output, hidden = self.lstm(x, hidden)\n",
    "        x = self.fc(output)\n",
    "        return x, hidden\n",
    "    \n",
    "    \n",
    "vocab_size = len(vocab)\n",
    "sample_inputs = torch.randint(0, vocab_size, (2, 100))\n",
    "    \n",
    "print(\"{:=^80}\".format(\" 一层单向 LSTM \"))       \n",
    "for key, value in CharLSTM(vocab_size).named_parameters():\n",
    "    print(f\"{key:^40}paramerters num: {np.prod(value.shape)}\")\n",
    "    \n",
    "CharLSTM(vocab_size)(sample_inputs)[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-12-13T02:36:48.216578Z",
     "start_time": "2023-12-13T02:36:48.192464200Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.310264Z",
     "iopub.status.busy": "2025-01-25T07:31:05.310092Z",
     "iopub.status.idle": "2025-01-25T07:31:05.313337Z",
     "shell.execute_reply": "2025-01-25T07:31:05.312890Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.310248Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1048576"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "4 * 1024*256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-12-13T02:37:02.308627800Z",
     "start_time": "2023-12-13T02:37:02.296635500Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.314209Z",
     "iopub.status.busy": "2025-01-25T07:31:05.313826Z",
     "iopub.status.idle": "2025-01-25T07:31:05.316926Z",
     "shell.execute_reply": "2025-01-25T07:31:05.316502Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.314194Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4194304"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "4 * 1024*1024"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.317712Z",
     "iopub.status.busy": "2025-01-25T07:31:05.317390Z",
     "iopub.status.idle": "2025-01-25T07:31:05.321661Z",
     "shell.execute_reply": "2025-01-25T07:31:05.321273Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.317697Z"
    }
   },
   "outputs": [],
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        \"\"\"\n",
    "        Save checkpoints each save_epoch epoch. \n",
    "        We save checkpoint by epoch in this implementation.\n",
    "        Usually, training scripts with pytorch evaluating model and save checkpoint by step.\n",
    "\n",
    "        Args:\n",
    "            save_dir (str): dir to save checkpoint\n",
    "            save_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.\n",
    "            save_best_only (bool, optional): If True, only save the best model or save each model at every epoch.\n",
    "        \"\"\"\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        self.best_metrics = -1\n",
    "        \n",
    "        # mkdir\n",
    "        if not os.path.exists(self.save_dir):\n",
    "            os.mkdir(self.save_dir)\n",
    "        \n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        if step % self.save_step > 0:\n",
    "            return\n",
    "        \n",
    "        if self.save_best_only:\n",
    "            assert metric is not None\n",
    "            if metric >= self.best_metrics:\n",
    "                # save checkpoints\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"best.ckpt\"))\n",
    "                # update best metrics\n",
    "                self.best_metrics = metric\n",
    "        else:\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:31:05.323411Z",
     "iopub.status.busy": "2025-01-25T07:31:05.323110Z",
     "iopub.status.idle": "2025-01-25T07:36:10.087490Z",
     "shell.execute_reply": "2025-01-25T07:36:10.087059Z",
     "shell.execute_reply.started": "2025-01-25T07:31:05.323396Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17300/17300 [05:03<00:00, 56.92it/s, epoch=99]\n"
     ]
    }
   ],
   "source": [
    "# 训练\n",
    "def training(\n",
    "    model, \n",
    "    train_loader, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    save_ckpt_callback=None,\n",
    "    stateful=False      # 想用stateful，batch里的数据就必须连续，不能打乱\n",
    "    ):\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "    }\n",
    "    \n",
    "    global_step = 0\n",
    "    model.train()\n",
    "    hidden = None\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            # training\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "                # 梯度清空\n",
    "                optimizer.zero_grad()\n",
    "                # 模型前向计算\n",
    "                logits, hidden = model(datas, hidden=hidden if stateful else None)\n",
    "                # 计算损失\n",
    "                loss = loss_fct(logits.reshape(-1, vocab_size), labels.reshape(-1))\n",
    "                # 梯度回传\n",
    "                loss.backward()\n",
    "                # 调整优化器，包括学习率的变动等\n",
    "                optimizer.step()\n",
    " \n",
    "                loss = loss.cpu().item()\n",
    "                # record\n",
    "                \n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss, \"step\": global_step\n",
    "                })\n",
    "   \n",
    "                # 保存模型权重 save model checkpoint\n",
    "                if save_ckpt_callback is not None:\n",
    "                    save_ckpt_callback(global_step, model.state_dict(), metric=-loss)\n",
    "                # udate step\n",
    "                global_step += 1\n",
    "                pbar.update(1)\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "        \n",
    "    return record_dict\n",
    "        \n",
    "\n",
    "epoch = 100\n",
    "\n",
    "model = CharLSTM(vocab_size=vocab_size)\n",
    "\n",
    "# 1. 定义损失函数 采用交叉熵损失 \n",
    "loss_fct = nn.CrossEntropyLoss()\n",
    "# 2. 定义优化器 采用 adam\n",
    "# Optimizers specified in the torch.optim package\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "\n",
    "# save best\n",
    "if not os.path.exists(\"checkpoints\"):\n",
    "    os.makedirs(\"checkpoints\")\n",
    "save_ckpt_callback = SaveCheckpointsCallback(\"checkpoints/text_generation_lstm\", save_step=1000, save_best_only=True)\n",
    "\n",
    "\n",
    "model = model.to(device)\n",
    "record = training(\n",
    "    model, \n",
    "    train_dl, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:36:10.088341Z",
     "iopub.status.busy": "2025-01-25T07:36:10.088008Z",
     "iopub.status.idle": "2025-01-25T07:36:10.206639Z",
     "shell.execute_reply": "2025-01-25T07:36:10.206098Z",
     "shell.execute_reply.started": "2025-01-25T07:36:10.088312Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAGdCAYAAADXIOPgAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAWsFJREFUeJzt3XlYU1f+BvA3YQlrQEB2UNw3cMENrdW2uNWx2sVprVO7qF1Gp3acsf0502m1narT2jqdacfutZu1tVN12rohilbBBRQEFxREQPY1YQ2BnN8fIReiQMuiuZj38zw8MTf3JidfA7ycc+65CiGEABEREZHMKC3dACIiIqKWMKQQERGRLDGkEBERkSwxpBAREZEsMaQQERGRLDGkEBERkSwxpBAREZEsMaQQERGRLNlaugG/hsFgQG5uLlxdXaFQKCzdHCIiIvoVhBCoqKiAv78/lMr294t0i5CSm5uLoKAgSzeDiIiIOiA7OxuBgYHtPq5bhBRXV1cAxjepVqu77Hn1ej327duHadOmwc7Orsuet7thHVgDgDUwYR1YA4A1MOlsHbRaLYKCgqTf4+3VLUKKaYhHrVZ3eUhxcnKCWq22+g+htdeBNWANTFgH1gBgDUy6qg4dnarBibNEREQkSwwpREREJEsMKURERCRLDClEREQkSwwpREREJEsMKURERCRLDClEREQkSwwpREREJEsMKURERCRLDClEREQkSwwpREREJEsMKURERCRLVh1SPo3NxH8zlEjNr7B0U4iIiOgaVh1SdqXk43C+EtllNZZuChEREV3DqkOKsvHS0UJYuCFERER0HSsPKcZbA1MKERGR7Fh1SFE09qQwpBAREcmPVYcUU08KMwoREZH8WHlIYU8KERGRXFl1SFFIc1Is2w4iIiK6nlWHlKaze5hSiIiI5MbKQ4rxlj0pRERE8mPVIUUBzkkhIiKSK+sOKexJISIiki2rDimmOSkAUwoREZHcdCqkrF+/HgqFAs8991yb+23btg2DBg2Cg4MDQkNDsWvXrs68bJfhnBQiIiL56nBIOXnyJN5//32EhYW1uV9sbCzmz5+PRYsW4fTp05g7dy7mzp2LlJSUjr50l+GKs0RERPLVoZBSWVmJBQsW4MMPP0SPHj3a3Pftt9/GjBkzsHLlSgwePBivvvoqRo0ahXfeeadDDe5K7EkhIiKSL9uOHLR06VLMmjULkZGR+Pvf/97mvnFxcVixYoXZtunTp2PHjh2tHqPT6aDT6aT7Wq0WAKDX66HX6zvS5FYY00l9fX0XP2/3YnrvrAFr0PzWWrEOrAHAGph0tg6drV+7Q8rWrVtx6tQpnDx58lftn5+fDx8fH7NtPj4+yM/Pb/WYdevWYc2aNddt37dvH5ycnNrX4DYUFigBKHH+/AXsKj/fZc/bXUVFRVm6CRbHGrAGJqwDawCwBiYdrUN1dXWnXrddISU7OxvLly9HVFQUHBwcOvXCbVm1apVZ74tWq0VQUBCmTZsGtVrdZa+zryIRp0sKMWDgQNx9W58ue97uRq/XIyoqClOnToWdnZ2lm2MRrAFrYMI6sAYAa2DS2TqYRkI6ql0hJSEhAYWFhRg1apS0raGhAYcPH8Y777wDnU4HGxsbs2N8fX1RUFBgtq2goAC+vr6tvo5KpYJKpbpuu52dXZd+WJRK45QchdLGqj+EJl1d3+6INWANTFgH1gBgDUw6WofO1q5dE2fvuusuJCcnIzExUfoaPXo0FixYgMTExOsCCgBEREQgOjrabFtUVBQiIiI61fCuwKsgExERyVe7elJcXV0xbNgws23Ozs7w9PSUti9cuBABAQFYt24dAGD58uWYPHky3nzzTcyaNQtbt25FfHw8Pvjggy56Cx1nOruHGYWIiEh+unzF2aysLOTl5Un3J0yYgC1btuCDDz7A8OHD8d1332HHjh3XhR1L4DopRERE8tWhU5Cbi4mJafM+AMybNw/z5s3r
7Et1OdNwDzMKERGR/Fj5tXuMt+xJISIikh+rDilNwz0WbggRERFdx6pDCntSiIiI5MvKQ4ppTgpDChERkdxYeUgx3nK4h4iISH6sOqSApyATERHJllWHFFNPCphRiIiIZMfKQwrP7iEiIpIrKw8pxlsO9xAREcmPVYcULotPREQkX1YdUniBQSIiIvmy8pDCnhQiIiK5suqQouA6KURERLJl1SGFK84SERHJl1WHFNMyKexJISIikh+rDilSTwpXcyMiIpIdqw4pnJNCREQkX1YdUjgnhYiISL6sPKQYb9mTQkREJD9WHVK44iwREZF8WXVIUTa+e/akEBERyY91hxTTnBSmFCIiItlhSAF7UoiIiOTIqkOKCeekEBERyY9VhxReBZmIiEi+rDykcMVZIiIiubLykGK85ZwUIiIi+bHqkMJ1UoiIiOTLqkNK07L4Fm4IERERXcfKQ4rxlj0pRERE8mPVIYXDPURERPJl1SGFE2eJiIjkq10hZdOmTQgLC4NarYZarUZERAR2797d6v6bN2+GQqEw+3JwcOh0o7uKQlonhSmFiIhIbmzbs3NgYCDWr1+P/v37QwiBzz77DHPmzMHp06cxdOjQFo9Rq9VITU2V7puGWOSAy+ITERHJV7tCyuzZs83uv/baa9i0aROOHTvWakhRKBTw9fXteAtvIAXP7iEiIpKtdoWU5hoaGrBt2zZUVVUhIiKi1f0qKyvRq1cvGAwGjBo1CmvXrm010JjodDrodDrpvlarBQDo9Xro9fqONvk6wtAAAGgwGLr0ebsb03tnDViD5rfWinVgDQDWwKSzdehs/RSinRMykpOTERERgdraWri4uGDLli24++67W9w3Li4Oly5dQlhYGDQaDTZs2IDDhw/j7NmzCAwMbPU1Vq9ejTVr1ly3fcuWLXBycmpPc9t0qliBzy7ZoL/agGVDDV32vERERARUV1fj4YcfhkajgVqtbvfx7Q4pdXV1yMrKgkajwXfffYePPvoIhw4dwpAhQ37xWL1ej8GDB2P+/Pl49dVXW92vpZ6UoKAgFBcXd+hNtuaHpBys+O4sxvRyx5bFY7vsebsbvV6PqKgoTJ06FXZ2dpZujkWwBqyBCevAGgCsgUln66DVauHl5dXhkNLu4R57e3v069cPABAeHo6TJ0/i7bffxvvvv/+Lx9rZ2WHkyJFIS0trcz+VSgWVStXi8V35YbGzNb590fjc1q6r69sdsQasgQnrwBoArIFJR+vQ2dp1ep0Ug8Fg1uvRloaGBiQnJ8PPz6+zL9sllNIpyJZtBxEREV2vXT0pq1atwsyZMxEcHIyKigps2bIFMTEx2Lt3LwBg4cKFCAgIwLp16wAAr7zyCsaPH49+/fqhvLwcb7zxBjIzM7F48eKufycdoOSKs0RERLLVrpBSWFiIhQsXIi8vD25ubggLC8PevXsxdepUAEBWVhaUyqbOmbKyMixZsgT5+fno0aMHwsPDERsb+6vmr9wMCq44S0REJFvtCikff/xxm4/HxMSY3d+4cSM2btzY7kbdLE3rpDClEBERyQ2v3QPjxFkiIiKSFysPKZyTQkREJFdWHVKkOSlcx42IiEh2rDqkKDknhYiISLasPKQYb3l2DxERkfxYeUjhnBQiIiK5suqQwnVSiIiI5MuqQwrnpBAREcmXVYeUxo4U9qQQERHJkFWHFM5JISIiki+rDikKrjhLREQkW1YdUjgnhYiISL4YUsA5KURERHJk1SGl6RRkphQiIiK5seqQ0jTcY+GGEBER0XWsPKQYb9mTQkREJD9WHlJ4CjIREZFcWXVIMa3mxoxCREQkP1YdUtiTQkREJF9WHlKMt8woRERE8mPlIYXrpBAREcmVVYcUaVl8dqUQERHJjlWHFM5JISIiki8rDynGWw73EBERyY9VhxQFe1KIiIhky6pDCs/uISIiki+rDinsSSEiIpIvqw4pnJNCREQkX1YdUkw9KQBPQyYiIpIbqw4pyqaMwnkp
REREMmPlIaUppXBeChERkby0K6Rs2rQJYWFhUKvVUKvViIiIwO7du9s8Ztu2bRg0aBAcHBwQGhqKXbt2darBXal5TwrnpRAREclLu0JKYGAg1q9fj4SEBMTHx+POO+/EnDlzcPbs2Rb3j42Nxfz587Fo0SKcPn0ac+fOxdy5c5GSktIlje8sBXtSiIiIZKtdIWX27Nm4++670b9/fwwYMACvvfYaXFxccOzYsRb3f/vttzFjxgysXLkSgwcPxquvvopRo0bhnXfe6ZLGdxbnpBAREclXh+ekNDQ0YOvWraiqqkJERESL+8TFxSEyMtJs2/Tp0xEXF9fRl+1SnJNCREQkX7btPSA5ORkRERGora2Fi4sLtm/fjiFDhrS4b35+Pnx8fMy2+fj4ID8/v83X0Ol00Ol00n2tVgsA0Ov10Ov17W1yqxrq65tes04Pe6V1BhVTTbuytt0Na8AamLAOrAHAGph0tg6drV+7Q8rAgQORmJgIjUaD7777Do8++igOHTrUalDpiHXr1mHNmjXXbd+3bx+cnJy67HXqDYCpBHv37YNTu6txa4mKirJ0EyyONWANTFgH1gBgDUw6Wofq6upOvW67fy3b29ujX79+AIDw8HCcPHkSb7/9Nt5///3r9vX19UVBQYHZtoKCAvj6+rb5GqtWrcKKFSuk+1qtFkFBQZg2bRrUanV7m9yqGp0OOH4IADB16lS4Odp12XN3J3q9HlFRUZg6dSrs7FgD1sB6awCwDgBrALAGJp2tg2kkpKM63XdgMBjMhmaai4iIQHR0NJ577jlpW1RUVKtzWExUKhVUKtV12+3s7Lr0w2Jodt6xjY2tVX8Qga6vb3fEGrAGJqwDawCwBiYdrUNna9eukLJq1SrMnDkTwcHBqKiowJYtWxATE4O9e/cCABYuXIiAgACsW7cOALB8+XJMnjwZb775JmbNmoWtW7ciPj4eH3zwQaca3VUUZuukWOd8FCIiIrlqV0gpLCzEwoULkZeXBzc3N4SFhWHv3r2YOnUqACArKwtKZdMJQxMmTMCWLVvw4osv4i9/+Qv69++PHTt2YNiwYV37LjpIoVBAAQEBBRdzIyIikpl2hZSPP/64zcdjYmKu2zZv3jzMmzevXY26mRQABHiBQSIiIrmx6mv3AE1DPuxJISIikheGlMZbzkkhIiKSF4YUqSeFIYWIiEhOGFIab5lRiIiI5IUhpfGWPSlERETywpDSmFKYUYiIiOSFIaXxlj0pRERE8sKQwlOQiYiIZIkhpfGWi7kRERHJC0MKe1KIiIhkyepDiqkAnJNCREQkL1YfUjhxloiISJ6sPqSApyATERHJktWHFPakEBERyZPVhxQle1KIiIhkyepDCntSiIiI5IkhpfGWpyATERHJC0OKNNzDlEJERCQnDCmNt+xJISIikheGFGnFWaYUIiIiOWFIabxlSCEiIpIXhpTGW2YUIiIieWFI4XAPERGRLDGkNN4yoxAREckLQwp7UoiIiGSJIaXxlhmFiIhIXqw+pCjZk0JERCRLVh9SuJgbERGRPDGksCeFiIhIlhhSGm957R4iIiJ5YUhpvOVwDxERkbwwpHC4h4iISJbaFVLWrVuHMWPGwNXVFd7e3pg7dy5SU1PbPGbz5s1QKBRmXw4ODp1qdFdSwBhO2JNCREQkL+0KKYcOHcLSpUtx7NgxREVFQa/XY9q0aaiqqmrzOLVajby8POkrMzOzU43uSqaeFM5JISIikhfb9uy8Z88es/ubN2+Gt7c3EhIScPvtt7d6nEKhgK+vb8daeINxMTciIiJ5aldIuZZGowEAeHh4tLlfZWUlevXqBYPBgFGjRmHt2rUYOnRoq/vrdDrodDrpvlarBQDo9Xro9frONNmMXq+XQoq+vr5Ln7s7Mb1va33/AGsAsAYmrANrALAGJp2tQ2frpxAdHOcwGAy45557UF5ejiNHjrS6X1xcHC5duoSwsDBoNBps2LABhw8fxtmzZxEYGNjiMatX
r8aaNWuu275lyxY4OTl1pLmteu+8EufLlVjQtwFjvdmdQkRE1FWqq6vx8MMPQ6PRQK1Wt/v4DoeUZ555Brt378aRI0daDRst0ev1GDx4MObPn49XX321xX1a6kkJCgpCcXFxh95kW2154F/ROFeuxLp7h+KBUQFd9tzdiV6vR1RUFKZOnQo7OztLN8ciWAPWwIR1YA0A1sCks3XQarXw8vLqcEjp0HDPsmXL8OOPP+Lw4cPtCigAYGdnh5EjRyItLa3VfVQqFVQqVYvHdvWHxXTtHqVSadUfRODG1Le7YQ1YAxPWgTUAWAOTjtahs7Vr19k9QggsW7YM27dvx4EDBxASEtLuF2xoaEBycjL8/PzafeyNwMXciIiI5KldPSlLly7Fli1bsHPnTri6uiI/Px8A4ObmBkdHRwDAwoULERAQgHXr1gEAXnnlFYwfPx79+vVDeXk53njjDWRmZmLx4sVd/FY6h4u5ERERyUu7QsqmTZsAAFOmTDHb/umnn+Kxxx4DAGRlZUGpbOqgKSsrw5IlS5Cfn48ePXogPDwcsbGxGDJkSOda3kWaVpy1bDuIiIjIXLtCyq+ZYxsTE2N2f+PGjdi4cWO7GnUzSXGKPSlERESywmv3sCeFiIhIlhhSGm85J4WIiEheGFLYk0JERCRLDCmNt7zAIBERkbwwpEg9KQwpREREcsKQ0njL4R4iIiJ5YUhpvGVPChERkbwwpDSmFGYUIiIieWFIabzlxFkiIiJ5YUjhKchERESyxJDSeMs5KURERPJi9SHFVAD2pBAREcmL1YeUpomzTClERERywpDSeMvhHiIiInlhSOHEWSIiIlmy+pBiwp4UIiIiebH6kGIqADMKERGRvFh9SJGGezjeQ0REJCsMKY23jChERETywpAiTZxlTCEiIpIThpTGW2YUIiIieWFIYU8KERGRLDGkNM5GYUghIiKSF6sPKUou5kZERCRLVh9STHjtHiIiInmx+pAiXbvHYNFmEBER0TWsPqSYhnvqOd5DREQkK1YfUhxsjLeVOr1lG0JERERmrD6kONkab7U19ZZtCBEREZmx+pDi2NiToqlhTwoREZGcMKTYGueiaGsZUoiIiOSEIUUa7mFIISIikpN2hZR169ZhzJgxcHV1hbe3N+bOnYvU1NRfPG7btm0YNGgQHBwcEBoail27dnW4wV3NNNxToauHgWf4EBERyUa7QsqhQ4ewdOlSHDt2DFFRUdDr9Zg2bRqqqqpaPSY2Nhbz58/HokWLcPr0acydOxdz585FSkpKpxvfFUwTZ4UwBhUiIiKSB9v27Lxnzx6z+5s3b4a3tzcSEhJw++23t3jM22+/jRkzZmDlypUAgFdffRVRUVF455138N5773Ww2V3HVgk42ClRqzdAW6OHm6OdpZtEREREaGdIuZZGowEAeHh4tLpPXFwcVqxYYbZt+vTp2LFjR6vH6HQ66HQ66b5WqwUA6PV66PVdN3fE9FyuKlvU6utQUlEDX1frCymmOnRlbbsb1oA1MGEdWAOANTDpbB06W78OhxSDwYDnnnsOEydOxLBhw1rdLz8/Hz4+PmbbfHx8kJ+f3+ox69atw5o1a67bvm/fPjg5OXW0ya2yadABUGD/oaPIdLPeeSlRUVGWboLFsQasgQnrwBoArIFJR+tQXV3dqdftcEhZunQpUlJScOTIkU41oCWrVq0y633RarUICgrCtGnToFaru+x19Ho9oqKi4OfphvyrWgwKG4XpQ31++cBbjKkOU6dOhZ2d9fUkAawBwBqYsA6sAcAamHS2DqaRkI7qUEhZtmwZfvzxRxw+fBiBgYFt7uvr64uCggKzbQUFBfD19W31GJVKBZVKdd12Ozu7G/JhUTvZAwCq9cKqP4w3qr7dCWvAGpiwDqwBwBqYdLQOna1du87uEUJg2bJl2L59Ow4cOICQkJBfPCYiIgLR0dFm26KiohAREdG+lt5Abg7GInLVWSIiIvloV0/K0qVLsWXLFuzc
uROurq7SvBI3Nzc4OjoCABYuXIiAgACsW7cOALB8+XJMnjwZb775JmbNmoWtW7ciPj4eH3zwQRe/lY5TN67oxlVniYiI5KNdPSmbNm2CRqPBlClT4OfnJ31988030j5ZWVnIy8uT7k+YMAFbtmzBBx98gOHDh+O7777Djh072pxse7OpG3tSuOosERGRfLSrJ0WIXz7zJSYm5rpt8+bNw7x589rzUjeVqSeFwz1ERETyYfXX7gEAtYNpuIcrzhIREckFQwqahnvYk0JERCQfDCloNnGWIYWIiEg2GFLAnhQiIiI5YkgB4O5kDCnl1fpfNTmYiIiIbjyGFAAejSvO1jUYUKnj5FkiIiI5YEgB4GhvAyd7GwBASWWdhVtDREREAEOKxMPZ2JtSUsWQQkREJAcMKY08G0NKKUMKERGRLDCkNPJ0MV51ubRKZ+GWEBEREcCQIjEN9xRzTgoREZEsMKQ04nAPERGRvDCkNPJgSCEiIpIVhpRGpjkpxZWck0JERCQHDCmNONxDREQkLwwpjTjcQ0REJC8MKY08XRoXc6us4/V7iIiIZIAhpZGns3FOCq/fQ0REJA8MKY0c7W3gaGe8fg+HfIiIiCyPIaUZ05BPgZZn+BAREVkaQ0ozA31cAQDJORoLt4SIiIgYUpoZ1asHAOBUVpmFW0JEREQMKc2MDHIHACRmlVu0HURERMSQYiYsyB1KBZBTXoMCba2lm0NERGTVGFKacVHZYkDjvJTTHPIhIiKyKIaUa5jmpZzmkA8REZFFMaRcIzTADQBwNldr4ZYQERFZN4aUawz1VwMAzuZquDw+ERGRBTGkXGOAjytslAqUVeuRp+HkWSIiIkthSLmGg50N+nu7AOCQDxERkSUxpLRgSLMhHyIiIrIMhpQWDPXn5FkiIiJLa3dIOXz4MGbPng1/f38oFArs2LGjzf1jYmKgUCiu+8rPz+9om2840+TZcwwpREREFtPukFJVVYXhw4fj3XffbddxqampyMvLk768vb3b+9I3jWm4J6e8BmVVdRZuDRERkXWybe8BM2fOxMyZM9v9Qt7e3nB3d2/3cZagdrBDsIcTskqrcS5Pi4n9vCzdJCIiIqvT7pDSUSNGjIBOp8OwYcOwevVqTJw4sdV9dToddDqddF+rNQ676PV66PX6LmuT6blaes7Bvi7IKq3GmewyjO3l1mWvKUdt1cFasAasgQnrwBoArIFJZ+vQ2fopRCdWLFMoFNi+fTvmzp3b6j6pqamIiYnB6NGjodPp8NFHH+GLL77A8ePHMWrUqBaPWb16NdasWXPd9i1btsDJyamjzW2XfVcV+CnbBqM8DXh0gOGmvCYREdGtpLq6Gg8//DA0Gg3UanW7j7/hIaUlkydPRnBwML744osWH2+pJyUoKAjFxcUdepOt0ev1iIqKwtSpU2FnZ2f22KGLRVj8xWnY2yqx4f5hmDrYG7Y2t+bJUG3VwVqwBqyBCevAGgCsgUln66DVauHl5dXhkHLThnuaGzt2LI4cOdLq4yqVCiqV6rrtdnZ2N+TD0tLzhgV7AADq6g149pszeHXOUDwS0bvLX1tOblR9uxPWgDUwYR1YA4A1MOloHTpbO4t0DSQmJsLPz88SL/2rebs6YIhfU+o7nlFqwdYQERFZn3b3pFRWViItLU26n5GRgcTERHh4eCA4OBirVq1CTk4OPv/8cwDAP//5T4SEhGDo0KGora3FRx99hAMHDmDfvn1d9y5ukC1LxmF3Sj5WfZ/Mhd2IiIhusnaHlPj4eNxxxx3S/RUrVgAAHn30UWzevBl5eXnIysqSHq+rq8Of/vQn5OTkwMnJCWFhYdi/f7/Zc8iVu5M9pg/1xarvk5FRXIXfvh8HW6UCmx8fC3vbW3N+ChERkVy0O6RMmTIFbc213bx5s9n9559/Hs8//3y7GyYXHs72CHB3RE55DU40DvlEny/AzFB5D1cRERF1d+wO+BWGBZjPSN5yIquVPYmIiKirMKT8CoE9
zNdmOZJWjOzSagu1hoiIyDowpPwK9wz3h0IBRA72wW39vCAE8G18tqWbRUREdEtjSPkVhge5I+bPU/DOwyPx0NggAMaQUt9gwI9ncvHSzhTU6hss3EoiIqJbi0UWc+uOenk6AwCmDvGBh7M9CrQ6jHw1ChW19QCAsSEe+E2YvyWbSEREdEthT0o7qWxt8EB4IABIAQUALuRVWKpJREREtyT2pHTA05P7orSqDq4Otsgpq8G+cwVILWBIISIi6koMKR3g4WyPDfOGAwBi04uNISXfGFL+l5SLpOxy/N/MQbC7RS9ISEREdDMwpHTSQB9XAEBWaTXyNDV4/rsk1OoNCAt0w5wRARZuHRERUffFP/U7ydNFBS8X4xWb1+++gFq9AQDwQ1KuJZtFRETU7TGkdIGBvi4AgJ2JTcFk//lC7D9XgPLqOks1i4iIqFtjSOkCA32als13Udmib0/j6cqLP4/H7z4+3ua1joiIiKhlDCldYN7oQAz2U2PaEB989sRY/H5KP+mxlBwt3jt0GWt3nUdZFXtViIiIfi1OnO0Cg/3U2L18knQ/vFcPzB7uj3/suYCPj2TgH3suAAAc7Wzwx6kDLNVMIiKiboU9KTeIva0ST9wWAhulQtoWm14s/TujuAr3/eco9qTkWaJ5REREsseQcgMFuDvizXnDce9I46nISdka6Ro/iz87iVNZ5Xj6y1OWbCIREZFscbjnBps7MgBzRvgjNr0YBVodFn58An29nZFeVCXtYzAIKJv1uBARERF7Um4KhUKBcSGeAIATV0rx9Ylss8ezy6ot0SwiIiJZY0i5SSb09ZT+7epg3oH16dErWL/7AuKvlN7sZhEREckWh3tukrkjA5BaUIEpA73Rt6czTmSUIi69BNsSrmJz7BUAwHuH0rF69hA8NjHEso0lIiKSAYaUm8TBzgYvzx4q3Q/s4YRKXT22JVw12+/9w5fxSERvs7OCiIiIrBGHeyxoiF/TSrWvzBkKN0c75GlqcTStGIcvFmHS6wdw/6ZY7ErmacpERGR92JNiQUP93RDg7gi1ox0eGhOMtMJKfB6XiW0JV2EQAtmlNcgurUFGcRWmDfGBrQ0zJRERWQ/+1rMgR3sbHFo5BTuXToS9rRL3jwoEAESfL0DyVY20X2lVHU5eKbNUM4mIiCyCIcXCbG2UsLc1/jeEBrjBVWWL6roGZJUaT0u+c5A3AHBlWiIisjoMKTKiVCowslcP6b63qwq/Gx8MAPgsLhNLPo/HleKq1g4nIiK6pTCkyMzoZiFliL8aE/t5wd3JDgAQda4Aj28+ifJq86spn84qw6WCipvaTiIiohuNIUVmzEKKnxoqWxtsWTwer84dhgB3R2QUV2HND+ekfX6+VIT7NsVi/ofHYDAISzSZiIjohmBIkZnhQe7SGilD/NXS7SPje+Gt3w4HAOw/X4D6BgNKKnVY8W0ShACKK+uQq6mxWLuJiIi6GkOKzDirbDFzmC+8XFSI6ONp9tjo3h5QO9iiorYeZ3I0+PO2JBRV6KTHLxdxvgoREd06GFJk6N/zRyL+xUh4uqjMttsoFZjQ1wsA8MTmkziYWgR7WyUG+boCANKLKlGpq8fCT07graiLN73dREREXYkhRYYUitaXxJ/Y3xhSyqv1AIC/3j1YOk05vagSm2LScPhiEf4VfQlCcI4KERF1X+0OKYcPH8bs2bPh7+8PhUKBHTt2/OIxMTExGDVqFFQqFfr164fNmzd3oKkEALf185L+PS88EAsjeqFvTxcAQFphJX4607SeSlljkCEiIuqO2h1SqqqqMHz4cLz77ru/av+MjAzMmjULd9xxBxITE/Hcc89h8eLF2Lt3b7sbS0BvTycsui0ED40Jwmv3hkKhUKCvtzGkHLtciisl1dK+2aXVrT0NERGR7LX72j0zZ87EzJkzf/X+7733HkJCQvDmm28CAAYPHowjR45g48aNmD59entf3uopFAr87TdDzLb16enc4r7ZZdUIC3TD
scul6NPTGT5qh5vRRCIioi5xwy8wGBcXh8jISLNt06dPx3PPPdfqMTqdDjpd01krWq0WAKDX66HXd90Qhum5uvI5LcHRxvz+xL6eOJpegitFlXh9z3lsOpSBsb174KtFY1o8/lapQ2ewBqyBCevAGgCsgUln69DZ+t3wkJKfnw8fHx+zbT4+PtBqtaipqYGjo+N1x6xbtw5r1qy5bvu+ffvg5OTU5W2Miorq8ue82W7zUSKhWIGnBjfgXFkRACXei7mISr1xEu6JK2X46addaGNO7i1Rh85iDVgDE9aBNQBYA5OO1qG6unPTDm54SOmIVatWYcWKFdJ9rVaLoKAgTJs2DWq1usteR6/XIyoqClOnToWdnV2XPa8l3A1A32CAnY0S2xJysC/nrBRQTMIn3QnfxiGfnYm5sFEq8Jswv1uqDh3FGrAGJqwDawCwBiadrYNpJKSjbnhI8fX1RUFBgdm2goICqNXqFntRAEClUkGlUl233c7O7oZ8WG7U895sprfQu/FsHxNbpQL1BoHsMh2CPF0Rda4Af/5vCgDgziF+cGo88FapQ2ewBqyBCevAGgCsgUlH69DZ2t3wdVIiIiIQHR1tti0qKgoRERE3+qWtVlCPpiGxgT6umDygJwDgXJ4W2+Kzser7ZOnxzBKuUktERPLU7p6UyspKpKWlSfczMjKQmJgIDw8PBAcHY9WqVcjJycHnn38OAHj66afxzjvv4Pnnn8cTTzyBAwcO4Ntvv8VPP/3Ude+CzPi5NZ3FMzzIDW6Odoi+APz9p/PX7ZtRXIXBPi2fHURERGRJ7e5JiY+Px8iRIzFy5EgAwIoVKzBy5Ei89NJLAIC8vDxkZWVJ+4eEhOCnn35CVFQUhg8fjjfffBMfffQRTz++gWxtlBjsZ5y788j43gjxMh/+efL2PtIqtZklXEuFiIjkqd09KVOmTGlzufWWVpOdMmUKTp8+3d6Xok747PExKK2uwyBfNarq6qXtfbycsWrmIPwnJh0HLhTiCod7iIhIpmR5dg91nrfaAd6NZ/I0X+ztN2F+UCgU6O1p3HalmCGFiIjkiRcYtAI9XVTo5WmcTDtvdBAASPevlFQjOUeDqhbW21nxbSKmbTyESl399Q8SERHdYOxJsQIKhQLfPhWBmroGBHkYw0lvL2NPSmlVHe577ziUChuctzmPV+8NAwAcu1yC70/lAACSsssxsdmFDYmIiG4G9qRYCR+1gxRMAMBFZZ5PDUKBL45nQ1Nj7FL5V/Ql6bEsXqiQiIgsgCHFig3ydZVuPVTGydApORqcvFKK2PQSaT9eTZmIiCyBIcWK/X3uMDw+sTe+eHw0erkYQ0pidrnUi2LqbWFPChERWQJDihUb3dsDL88eCncnOwQ3hpTPYq/g50vFsFEq8FxkfwBNPSmZJVU4l9u56zAQERH9WgwpBABSSCms0AEA7hsZgAl9jZNls8tqUKCtxW/+fQT3/ucoSip1FmsnERFZD4YUAgAEXbMy/srpAxHkYbwAZGlVHf66PRkVtfXQ1Rtwlr0pRER0EzCkEABAZQOEBhiX0n9hxiB4qx3g6mCHHk7GK1juP18o7XuxoMIibSQiIuvCdVJIsuH+UFwqrsasUD9pW5CHE8qqNQAAVwdbVNTW40I+QwoREd147EkhSZ+ezvhNmD8UCoW0zbT4m6ezPV6cNRhAU0/KscsleP67JDz1RTw01S0sWUtERNQJ7EmhNs0Y6oufLxbhtXuHYYCPcV2ViwUV+F9SLpZvPQ3TtSYn9svBwojelmsoERHdchhSqE2zh/tLFyVsMAiobJWo1Rvw7NfmV7VOytYAERZqJBER3ZI43EO/yDT8Y6NUwKfxysqAMcB88Eg4AODM1XJU1Orx8s4UPPRBHCpqOfxDRESdw5BC7TJ9qA8A4Hfjg/HPB0dgRLA7AOBSYSXu2HAIn8Vl4tjlUsQ1W1afiIioIzjcQ+3ywoxBeGR8bwR7GifUers6wM/NAXmaWhQ3W+Qt
p7zGUk0kIqJbBHtSqF1sbZRSQDExnQEEAAHuxgXgrpY1hRRNtR5bjmehuq7+5jSSiIhuCQwp1GkT+noCAJQKYPGkEABATrOQcu9/juIv25PxWWymRdpHRETdE4d7qNOWTOoDg0HgvlGBSCusBABkllbjlR/Ooby6DpeLqwAAe1Ly8MyUvhBCQFOjh7uTvSWbTUREMseQQp3mrLLFimkDAQA1+gYAwPk8Lc7nmV/jp1ZvgMEgsPybRPyQlIstS8ZJFzFMK6zE+t3nsWRSH4zr43lz3wAREckSh3uoSwX0cLxum2/jacuXiyvxjz0X8ENSLgDg4AXj9YD0DQZEvnUI+88XYt3uCzevsUREJGsMKdSl1A52UDs0ddC9fn8YYv/vTjjZ20DfIPD+4cvSYxnF1QCAD39u2laorb15jSUiIlljSKEu591swbfRvXtAqVSgn7fLdfulFVZACIEv45om1NY1GG5KG4mISP4YUqjLlVXVSf8O8XIGALOQ8sj4XgCMk2sTs8uRq2nqPSmurENt47wWIiKybgwp1OV6uqqkf5uW1Pd3a5qrsmB8MNyd7CAE8EHj8E/kYG842dsAAPI0tcgsqcLz3yXhclHlTWw5ERHJCUMKdbmND47AqGB3/PeZCdK2YQFq6d8DfVzRv7FnZXdKPgAgcrAP/BsXgsstr8HMt3/Gt/FXsXbX+ZvYciIikhOGFOpyg/3U+P73ExHeq4e0bfpQX7x27zDsXj4JCoUC/bxdpccUCuDOQd5SSNmZmIPqOuOQz7lc42nM5dV1+Op4JrJLq2/iOyEiIkviOil0UygUCiwY10u637/ZHJU/Rg6At9oBAe7GCbffxl+VHquqa4AQAn/dnoKfkvOgUAArpw/E76f0u3mNJyIii2BPClnErDA/jA3xwOrZQ/DsXf0BmM9bMdHU6JGSo8Wes8ZhISGArSeypcerdPUQQtycRhMR0U3FkEIW4aN2wLdPReCxiSHSNj/3ppAyoa8nejVeyHDND2fRYBDSGULZZdWo1TfgREYpRryyD6/vTb25jSciopuiQyHl3XffRe/eveHg4IBx48bhxIkTre67efNmKBQKsy8HB4dW9yfr5efW9LmYNzoQ/XoaQ0l8ZhkAYMXUAXBzNJ4VdLmoCo9+cgL6BoFNMekWaS8REd1Y7Q4p33zzDVasWIGXX34Zp06dwvDhwzF9+nQUFha2eoxarUZeXp70lZnJq+HS9fr2bJqnMn2oL/r5NN0P8nDEtCE+Um/K8YwS6TpBAGAwcMiHiOhW0+6Q8tZbb2HJkiV4/PHHMWTIELz33ntwcnLCJ5980uoxCoUCvr6+0pePj0+nGk23Jl83B2x9cjz2PDcJTva2CPF0lh57enJf2Noopd6VN64Z4slvYzn9BgYYIqJuqV0hpa6uDgkJCYiMjGx6AqUSkZGRiIuLa/W4yspK9OrVC0FBQZgzZw7Onj3b8RbTLW18H08M8jWuqdL8FOYHwgMBNK1cazpF2SSjuKrF50srrMS4tdFY8W3iDWgtERHdSO06Bbm4uBgNDQ3X9YT4+PjgwoWWr147cOBAfPLJJwgLC4NGo8GGDRswYcIEnD17FoGBgS0eo9PpoNPppPtarXGtDL1eD71e354mt8n0XF35nN2RXOvQ28MBWxaNgb+7A5TCAL3egN6eTfNW3BxtMdRPjdjLpbiUr0EPRxv8MzoNj0/ohdG9eqDBILDw4+MortTh+1M5eO2ewbC1aTmXy7UGNxNrYMQ6sAYAa2DS2Tp0tn4K0Y7zN3NzcxEQEIDY2FhERERI259//nkcOnQIx48f/8Xn0Ov1GDx4MObPn49XX321xX1Wr16NNWvWXLd9y5YtcHJy+rXNpVtQSS3wymljtr7DzwAFgAN5SozwMCCpVAEBBQKcBJ4f3oAj+Qpsy7CRjl01vB6+/PgQEd001dXVePjhh6HRaKBWq3/5gGu0qyfFy8sLNjY2KCgoMNteUFAAX1/fX/UcdnZ2GDly
JNLS0lrdZ9WqVVixYoV0X6vVIigoCNOmTevQm2yNXq9HVFQUpk6dCjs7uy573u6mO9XBYBD4IP0w8rU6vPzwZBxJK8GBneeQWNrUQ5JTrUDIyEn4x1enATTNVek5YCTuDvNr8TnX7jqPotxMbHg8UvY1uFG60+fgRmIdWAOANTDpbB1MIyEd1a6QYm9vj/DwcERHR2Pu3LkAAIPBgOjoaCxbtuxXPUdDQwOSk5Nx9913t7qPSqWCSqW6brudnd0N+bDcqOftbrpLHb7//UTU1RvQ28sZBRVNXYlujnbwc3PAhfwKPPftGeRqauHlosKdg3ri2/ir+OxYNj47lo1CbS0W3RaCxZP6AACizxfgs+NXAdjgpToBXyf51+BG6i6fgxuNdWANANbApKN16Gzt2n12z4oVK/Dhhx/is88+w/nz5/HMM8+gqqoKjz/+OABg4cKFWLVqlbT/K6+8gn379uHy5cs4deoUfve73yEzMxOLFy/uVMPJevm7O6K3l/HMn77Nltf/x/1h+OPUAQCAy40TaR+f2BsjgowTcJOyy5GUXY48TS0+/PmytFLtp0evSM9hulZQbHoxLhZU3PD3QkR0owkhuu3K3O2+ds+DDz6IoqIivPTSS8jPz8eIESOwZ88eaTJtVlYWlMqm7FNWVoYlS5YgPz8fPXr0QHh4OGJjYzFkyJCuexdktbxcVNgwbziUCmDGMF/UNxjw4OggXCiogJ/aAY9O6I20wkppf3sbJeoaDCjQ6pCrqUVFrR5H0oqlx1NytVDY2ODxT0/CR63CsVV3QaFQAAB2nM5BnqYWT0/uI227VmFFLS4VVGJiP68b+8aJiH6ljfsv4cPDl/HdMxEY6u9m6ea0S4cuMLhs2bJWh3diYmLM7m/cuBEbN27syMsQ/Sqm05MBwNZGiX88EGb2+ECfpisuLxgfjITMMpy5qkFCZhl2ns4x2zcxW4P/ns4FABRodcgorkKfni748PBlvLbrPABgYj9PhAa4YcW3SajVN+Cdh0fBRmkMLU9/kYBTWeXYsmQcJvRlUCGiG6uwohb5mlqEBbq3+HiBthbvxaSjrsGA7adyul1I4bV76JbnaG+D344OxBA/NZbf1R+jgo3DP+8cuIToC4WwVSrw6j3Gnr0DqUXILKmWjk3MLsfprDIpoABAQmYZErPLsf10Dnan5ONSoXFYKK2wEqeyygEAxy6X3qR3R0S3KoPBfJimolaP6rp66X51XT3u+08s5rx7FGeulrf4HB8fyUBdgwEApF7jr09k4YnNJ1FeXQcAOHSxCAmZpbIcEmJIIavw+gPDsWv5JLg72WNU4yJxFwuMw0ALxgVjVqj52j9qB2Mn4+mscnx9IsvssYTMMuxo1gNz5qoGALAzsWlbcis/MIiI9A0GvLwzBZ/HXWl1n/oGA+a8exSRbx1Crb4BhRW1uGPDIcz61xHoG0PHOwfScLWsBkIA3yVcxeGLRdgUk45dyXkAjL0oXx5rugzNhfwKJGSW4qWdKThwoRBbT2ZDCIFXfzyH+zfFYWdi7g193x3BkEJWp/lKtgN9XLFi2kC4OjTNQO/j5YzX7g0FABxNL8aPZ4zf8M9F9gcAnMgoxQ+N2wAg+aoGDQZh9g1+5qoGQgj8kJSLtbvOmy3NH5degkWbTyK7tKnH5lq1+gZ8fSILJZW6VvchoptLCIH4K6Wo1Te0uk9FrR6r/3cWhy4WtbrP1yey8FlcJl7aeVYKHLHpxWY/E6IvFCI5R4P0oioczyjFW/suorjSOAR97HIJMoqr8OHPl6X9P4/LxMJPTuAfey7g91+dQlJ2OdbuOo/qugaMDHbHED/j8h33b4qDvsH482hnYi5OZZUjrbASDnZK3DXYu1P1uREYUsjq+Ls54J7h/hgX4oEtS8bBzdEYUH7XrwFTBnhhy5LxGBnsDsB4teXqugb08XLG4kl9oFQAhRU6lFbVSc93JkeDt/dfRFZpNVxVtrBVKlBSVYcfz+ThD1+fxgeH
L+PQReMFOGvqGjD/w2OIvlCIjfsvAjBOyH34w2MoaHb9oWe/Po1V3ydjw76LN6kqRLcOg0Egt7ymzX0uFlTg918l4EJ+6+t41DUAUecKpe/3jVEX8cB7cXjtJ+Pw78+XijBhXTQOpjZdYPev21OwOfYKFn92EkII5Gtq8coP55CYXS617fO4pt6N9KJKJGaX4+EPj+OxT09IQy5bjjf14H5wOB3fxGdL93cl5+GNvRegbxCY1N987ltPV+PyHS/tTMHOxFwoFMAr9wzDpAFN+ykVgK1SgfN5Wrz20zkAwN2hfmZ/rMkFQwpZHYVCgX/NH4lvnoqAp0vTejxjegp8+Mgo+Lo5IMDdEd6uTY89PbkvXFS20g8AAHhsQm8AxlOb/3XAuDjh3+8dhoG+xom6f/j6tLTviYwyADD7yyf+inFuy3PfJCI2vQRfNP7gyi2vwb5zxgUTTUNNe1LyMX3jYZzIkM9cl8KK2jZ7g6h7Sswuh7bGMkvBp+RoEJde0uY+X8RdwZOfx6OituU2Vunq8cgnxzFh/QEcuFDQ4j4Gg8A97xzBruR8vLTTeC257aevYsxr+6XvseQcDdYn2eD3Xydi5bYkpORo8G5MOgDj0K6uvgH/ir6EXE0t3j9k3H4uV4v/JRl7VPUNAnvPFuCODTH45GgGXt6ZAgD48nim2RmH53K1iD5vbGd6URXO5mpxsaAChy819cQcTSuBEECfnsalF74+kY1dyflQKoAXZw3BExNDAABTBvbEvx4aCQBIahyG/t24XggNdMP8McEY5OuK2wf0xIcLR2PKwJ4AIM2je3B0UJt1txSGFKIWKBQKvP5AGBbdFoL/PhOB344xfgP/tvEbedFtIXh5tvlp9AvGBWPOiIAWZ9nHXynFnpR8/Cv6krQtq7QaD77fdGHO4xnGH87/PtC0GrO9jRIJmWV4+ssEpBZU4O3o1ntWSip1WPPDWaTkaKRtFbX66ybDxV8pNeu1uZa+wYCV25Lw5r6mK01f+xy6+gbc8++jmPn2z2a9SteqqNUjrbBpvZn6BoPUvU1d52yuBi/tTGnz/yIhsxSRbx2SfiEWamvxxt4LZsdEnSvA3HeP4vnvU9p8vYTMMuQ09lSUVtXh4Q+P4aNmAbwll4sqkacxHiOEwJbjWTh+uSmQaKr1ePD9ODzy8fFWe0HO5mrwt51nse9cAXa0MH9CCIEln8fjaJrxeX9MMg7LbovPxoKPjqGowjh8ujMpB7V64+fwREYpSqvq8MdvklBUocOb+1JRqavH0q+TUKIznrUXfaEQ//f9GWnYVltbjy+PZeHklTLpOUoqdfhPjPlK6k9/mYCaxqGhpKsavLwzRQpFJufztDh8qWkZhE+OZuDxT09CCGB8Hw/Y2RjbYG+rxCePjoGHs72070NjgzHQ1xV/nj4Abz80ApsWhGN8Hw/09jRe/yPA3REvzBwEAOjt5Yw9z92Oz58Yi7sG++CJ20LgojLOvRvip8bYEI8Wa25pDClErZgy0Bt/+80QhPdq+uZdfld/HPjTZPztN0OgUCgweYDxr5FBvq5Yfc9QAMDdob5wtLPB7OH++O5p4zWu4huDRr1BYO4If4QFGk8D1NU3/cI+nVWOPSl5ZhN16xoMuH9TrHQ//kpZq+Phq75PxqdHr+Clxr/Yvj91FaGr95l1LW+Lz8YD78XhmS8TWn3fu1IKsC3hKv59IA35mloUVtRiwvoDWLrllLRP9PlC5GtrUamrR/yVlnt36hsMmP/hMUzbeBhnc43Bacnn8Ri3NrrNX6bdkcEg8OKOZKzclgSD4eaeIVFaVYdZ/zqCz+My8enRDADARz9fRsS6aOkvdiEEXtp5FmmFlfjgsDFMvLbrPN49mI5/H2gKzl80TrI8dLEY1fVASo4WE9ZFY2vjZ7K6rh7Pfn0a92+Kxe8+Og4hBDYfzUBsegn+uf9SqwH0SnEVZr79M+a+exS6+gZEny/EX7Yn4/dfnZLqtS0hG1V1Dag3CHyX
cBWRbx1C+KtRWPxZPDQ1etTVG/DijqbwdPBCIYQQ+PuP5/DYpydQqavH/vOFiG3WExObXoLiSh1e2nkWR9NK8M3JLBRoa6XhGpMnP49vamtJFTbsTUWephaeKgEftbH3NCVHC1ulArMaL61hGiYBAIMwzu+IPm8c9pk7wl96TO1gi2APY2j4rPF78dm7+mPdfcZ5b3GXS8zOzPn+VA5yymvQx8sZ7zw8CuP7eAIw/mHU28sZL84ajEn9vbD23lCsafyZ42RvizkjAuBobwOFQoGV0wehv7cL3n5ohBRErjWhrxcSX5qKA3+ajG1PR7S69pOlMaQQtYOtjRJ9ejatcvu33wzB/80chO+emQC7xissT+rfE2fXTMe/549EeK8e0pwXAJg2xAcb5g2XfvAAwEcLRyPYwwn1BoGnvzQGgUfG9zL7y0Zlq4SLyha6egM+PXoFT2w+ibDVe3GyMSDsScmXhohOZZXjclElVnybBABY84PxL7d8TS1WfndG2qe0qq5xwm8OrpYZh22EAD46ckV63egLBfjqWBbyNLXYnZwHbWMX+7fNxseTrhqfa867R/Hg+3HSX5vfxGcjJUcLgzD+QskurcbB1CKUVtXhWONf0J1ZCVNT3faQRE1dA/ak5JudsnmtgxcKMX5tNDbsTW11H4NB4KOfL2Pv2fxW9/ku4Sq+PJaFbQlXcTzDeCrnp0czzOp0rQaDwKLNJzFt46FWhy5q6hpwNlfT5urHL+5Ilv4dl14CXX0D3jmYhjxNLbafvgrA2ENytnE15fjMMhRV6LC/8fNyvPF0+dzyGvzcOMRQbxA4V6bAfw5dRq6mFptjrwAANsWkS8MZGcXGoYmvGudOGANrGRoMAhcLKszC9HuH0qGrNy6iePBCIT5vDEMlVXVILajA+TytFJAA4F/Rl5BWWImSqjrsP1+Ahz44hsc+PYHTjUMTgPF02n8fSMNHRzIQk1qEj3/OkHoan5gYApWtEvnaWqzcliT1Zuw5m4/lW0+juLIOg/3UmBXqJ9XEpECrk97vb/sYcN/IpsBxd6gfljReTsOURU0TUl/58Rxq9A0I9nDCH+7qLx2z7M5+mD82WLo/sZ8nVkwdgKH+xuNScrQQAgjycIR948+Qvj2d8eXicfByUWHtvaF47d5h+GOkcTXt+0YF4otF4/DwuGDpZ861ZoX5IWrFZIzu3XbviOnnmXMrQUYO5Nsyom6gn7cL+jVbmt9E2bi4m0KhwEBfV2mc+x/3h8HWRokHwgOxOfYKfhPmh8ghPoi+UIisxr9WB/m64i93D8brey9Ix00d4gNPZ3t8FpeJf+y5IL3O+4fSMdDXFS//z/gXpq1SgXqDwJx3j0r7GIRxnP71vU3HAcZfaGeuluP9w5cxulcPfL14DM6VK3Ahv+kX4t6zBbiQp5We51RmGfr7uOJwszMXjl0uxdG0EiQ1Tgw8l6tFLy8nvNVs0u/xjFKzH4RnrmowdYgP5n9wDGXVdfjxD5PgaG+8YvXhi0XYdy4f/zdzcIt/BdbUNeDP3yXhpzN5ePuhEZgzIuC6ffI0NVjyeTxScrSIHOyDjx4dfd0+e1Ly8XRjj9I7B9Pwx6kDYKNUoLquHs1z086kHPz9p/NQ2Spx6m9TIQC8+sM5jO7dA/NGB6GkUoe1u5v+Mv9fknGy4pofjH9pT+jricAe119+e+vJLERfMP7lvSs5Dw+OCTZ7XFurx9S3DqFAaxyi2LJ4HCb080J6USWySqpxxyBvFFXosCu5KTyl5lfgwPlClDcGuGOXjYHpn/ubeksaGi+oWVVn/MV9Pl8LTY0eXxzLNHvfRwqUyKoy/j9fyK9ASaUO358yX/zwj98koqRZr9gfv0lEXYMBpVV1mDrEBx8uHI2rZdX476mr0j7/2JOKjMbLVgDAgo+OX9ezVt+YAJbd0Q9bTmThfONnUGWrxAcLR+NvO1KQVVqNt6KaPmNvR1+EQQBO9jZYdmc/XCyowJG0YhxMbfqspuQY
n8fZ3gbvPjwSKbla/NR4uu78scFIK6yQhnBu7++JQe4FCB7sg02HjD1Uj07ojRFB7lh7byiSc8rh4WyPB8KDMH3jYWktkllhfujj5YwHwgNRWlWHRyf0xtWyGun79oUZxuGXAc0WmQSMAai/tysuFVTg2bv6S98vQR5OWDCuF6wVQwrRDfbnaQPx95/O4f9mDEKPxvHkAT6uSF49TfrLac4If3xzMgt3DPTGxodGwNHeBoN9m674PWdEANQOtlJ38V2DvBF9oRAHU4uwclsSCrQ6hHg548ExQVi/+wIqas17D3Ym5uJ/jWP44b16ICGzDG9GpeJykfGXRXxmGZKuavBNurE9kwf0xKGLRWZhBABOXinFF3GZMAjA3ckO5dV6JDT7KxQAfk4rwo9n9CipqoOHsz1Kq+qQkFlmNrSVnGNcf8b0F+yxyyW4Y5A3LuRr8eQX8ajVG+CrdsCyO/ubPXehthaLP4+X1qbZHHsFUwZ442CqcfjpsQm9ka+pxYKPjktzJvafL0BKjga6egOOphVj/thg9HRVYdM18wfOXC3HkUvF+Gf0Jdzuo8QsGIc3/rHb2Muiqzfg8MUi7E7Jx/+ScvG/pFxMH+aLtbsuoLxajx5Odiir1mN3Sp7ZhOL/JeXi0Yje+PtP5xDi5Ywnb++L9KJKvL6nqfdm++kczAsPwnPfJEJX34B3Hx6Fb09mSwEFADYdSoeNUoFHPz2BWr0B/31mAtKLjMM5g/3UuFJchQpdPd5o1iuUlF2OnYm5OJenhbO9DaYP88X3p3Kwvdk6P0IAb+y9IPWIPDW5D94/dBkZFQoATanlvUPpyCmvgbO9DZbd2R//2HMBlxqHk4YHuSMpuxz5zeY6RZ0rwDcns/CfmHToGwT6eDnjcnGVFFBMgdoUUEID3DBvdKA0Z8PBTolnpvTFg2OC8N9TV1FUocMD4YEYGdwD04f64MOfjcHhiYkhiE0vxoX8CtgqFXh59hB4ONtjYj8vafGyB8IDkZhdLg1/rbp7MPr0dIGfmyMiB/tgkK8rVkwdgLW7zksh5anbQ1B8rgBD/V2x+LYQ2NooMarxrL+HxwUDaAqV7z8Sjqe+SEBdgwH3DPeHQqHAhnnDpcf79nTB+vtCYW+rlOasOdjZSHWbOsQHy+7oJ8uzayxNIeS4xNw1tFot3NzcoNFooFarf/mAX0mv12PXrl24++67rfoql6yDPGpQXVcPJ/umvxsuFlRg2sbDAIDUv8+AytYGsWnF8HRRYaCvK+a8c0SawQ8Y/9Lu6+2COzfEQFdvwLN39UdmiflfsRP6euLxiSFY0mwM/lp9vJzw47OTMG3jYVwtM/6iDw1wQ3KzCbn2NkrsXDYRM9/+Wdo2NsQDJzJKEeDuiKIKHeoaDPho4Wj8aVsSNC2cLWL6pQ4YLwQ5c5gf/vhNohQuvFzsMX9sMM7nVcDD2Q4rpg7E/ZtikVNegx5OdtDW1putPwMY/xqOSS1EnqYWfbycEeThhEMXi+Bgp5QmSkYO9saLs4ZgyoYY2CgVGBnkbtbdb5L44p146YcL0vBGS2YO88XulHwoFMC2pyLw9JenUHzN2jZeLvYYEdQD+xsnrD49uS+2nsxCebUefXs6I72oCgoF8PJvhmB1Y+/LRwtH4+X/nUVOeQ3+cGc/vHswDddOdVl0WwiullVj79kCLL+rP45dLsHxZmd/uahsUalrCqtL7+iLiD5e+N3Hx6Vtg/3UUi+FqX5r7x2Gxz89gZiLxl/wvT2dcKXZKsz3jQrAH+7sjzs2xAAAfNQq/LDsNoxdGw3AOB/D3laJb+ObPndBHo74ctE4/HlbEk5eKcPwIHc8MbE3lm9NBGBcm2j/islQKhWIfOsQ0gorMWOoL957JLzFulfp6vHNyWxM6OeJQb5qpOZX4OMjl/HQ2GBpRemSSh3+sj0Zkwd4Y/7YIHxy9Ape/fEcRga7479PT5B6O5tLvqrBnHePYEJfL3yycCR27979q38mpOZXoKhC
h9v6//rLYRRoa5FVWo3RvXrIdk5IZ382dvb3N3tSiGSieUABjL0tnz42Bj5qB6hsjUMhE5pduPCB8EAppPztN0Okx356dhLsbJUIcHfEdwlXzULKM1P6YniQu3R/bIgHZg/3x98aJyTaKwXemhcGJ3tbvP3QCBy8UIRJ/b3g5arCXW8eko5bHtkfg/3UUm+Kva0Sa+8NReRbh6SQMam/F+4a7I0xvXtgf+OEwtAAN6QWVKCu3iAFFMB4JerNsVeksfnK2noUV9aZnel05qoGOeU1COzhiK8Wj8Pq/52VuvI9ne1RUlUnTTru7emEb56KQFl1HY6mFUsBBQD2ny+UrrU0oa8npg3xaTGkPPzxSZzLM/51/vspfaXTzAFgqL8aZ3O12J1iHGpZMC4Yo3t74MnbQ7Bu9wUIYZx/FHOxCMWVdVJAAYw9EoCx9+GTR0fjmS9P4cSVUrzyY9NEzL9sT0ZhhQ4ezvZYekc/ZBRXSYsKmvwvKRdVjSEkcrAPausbpJAyZ4Q/bJVK6f/eVWWLxbf1gZujHZbe0ReaGj0mD/CGpkaPP29Lkp7jlTlDoVAo8OEjo/D597swNDwCZTUNeOarpknTD4wKRIiXMwb7qZGar8XGB0fAW+2AV+cMRVZpNVZOH4TiSh1+SMpDjb4BUwb2xOv3h8Fb7dA4/FODof5q1BuEFFIeieglhYZ7hvvjn/svYsF48+Gv5pxVtnjithDp/kBfV7z+wHCzfTxdVHj/kaZhvkcjesHbVYXJA3u2GFAAIDTQDYdW3gEvFxUUivadhTbQ11VafuDX8lE7wEft0K5jrA1DCpGM3TGo9RUgHxwTjPJqPcb39cSYZhPkens5S/+eMrAnvFzsoVAo8OdpAzCpv/FspNWzh+BKSTVemDEIDULgzX2pqKlrwOIBemlCX3gvD+nMJiEEerqqUFShw70jA/D7KX0BGOfYbIpJx/r7Q9G3Z9PrAsDrD4RBoVBgxjA/7D9fiMkDemLDvOG4880Y1DUO/bz3u3BpXogQwOzh/vj73GHYFJOO9w6lw95WielDffFDUq40V+apyX3Ry9M4tHUwtQhujnbYsXQiln19WpoXs+6+MPR0VaGnqwo/PnsbtDX1GOKvxks7UvD96RzsPWsMDbOH+2N0sxWI548NgoeTHd6NuYxzeRVQKICND47A3aF++PxYJsqr9Zg2xAfr7gvF+HXR0DcIRPTxxF/uHgwAePL2vpg/NhjZpTXo6+2M9bsvYHPsFYR4OmPl9IH44OfLOHNVgycm9saKqQPhaG+D52cMxKOfnEBVXYM0BFLYeKrs05P7wMHOBn+dZZyfc/uAnrhjoDdCV++VTqf1Uasw1F+NSl093j90GQN8XLDuvlDEpBbhv6euopenE95+aKQ01Lhy+iDp/Wpq9Bjb2wPDg9zwwoxBsG02EdPLwTg0WKUXcLa3gUGYh+EvFo1FeXUd+nkbfzE/EtFbOtbf3RE//OE21OobMCyg6YJ27k72cHcytsPOxrhe0fk8LX43vmnOxbI7+pmdHttVbG2UmD3c/xf3C2o8E0evb19IoRuDwz1WPswBsA7ArV0DXX0DbJVKqfegJVfLqqHX1+NM3MFWaxB/pRSpBRV4aExwq8/18ZEMfHU8E2/9dgRGNOuxKarQSWFp/e4LeO9QOu4dGYCND47A+LXRyNfWQmWrxJnV06CytUGlrh6fxV7B9KG+8FarMH5tNKrrGuBgp8Txv0TCzdEOQgj8LykXoQFu6NPTBccvl+B3Hx/Hg2OC8Pe5oS22L7OkCr/51xFU6OoxwMcF/31mAlxUtnjqiwQUVerw+RNjkV9ehan/NE48fu3eYdKkxZjUQiRklmHpHf3gYGeDHadzcKmwAn+4sz8c7GxarW19g0H65V9Xb0BNXQPcnMzrezZXg/W7L2DOiAB8ePgyUgsqMKm/Fz57fGyLf/U/82WC1IvzwSPhmDbUF4BxIbb+3sazNYQQOJurRd+eLtKk5F/r2u+H
nPIaqGyV8Gq2+OGt7lb+mdAeHO4hohvKNFTUlsAeTtDr9TjTxj6je3v84imNi24LwaJm3fAmzVfq/cOd/TAuxENaznvtfcPwyZErWH3PEKmtLipbLL2jn3TMvSMD8NXxLNwd6ied0q1QKMzO7BnXxxPJq6dDZdv6ygq9PJ1x8sVINBiE2dlGHyxsGhZw8HTG4wMaMDp8FGYND5S2TxnojSkDm3q25o68/qyiljTvnbC3VcK+hfYN9XfDF4vGAQBCvJzxQ1Iu/nBnv1aHJZZH9oeNUoHHJ4aYXYuqeTBUKBRmvRidEeDu2CXPQ9ReDClEdFM5q2zNhrHuHOSDOwf5tHGE8WyMft4uuG9kYJv7tdWj0Z59RngKTBvSdptulPBePcyCR0sG+arxzsOjblKLiCyHIYWIZM9FZYvHJ17fQ0NEtzauOEtERESyxJBCREREssSQQkRERLLEkEJERESyxJBCREREssSQQkRERLLEkEJERESyxJBCREREssSQQkRERLLEkEJERESyxJBCREREssSQQkRERLLEkEJERESy1C2ugiyEAABotdoufV69Xo/q6mpotVrY2dl16XN3J6wDawCwBiasA2sAsAYmna2D6fe26fd4e3WLkFJRUQEACAoKsnBLiIiIqL0qKirg5ubW7uMUoqPx5iYyGAzIzc2Fq6srFApFlz2vVqtFUFAQsrOzoVaru+x5uxvWgTUAWAMT1oE1AFgDk87WQQiBiooK+Pv7Q6ls/wyTbtGTolQqERgYeMOeX61WW/WH0IR1YA0A1sCEdWANANbApDN16EgPigknzhIREZEsMaQQERGRLFl1SFGpVHj55ZehUqks3RSLYh1YA4A1MGEdWAOANTCxdB26xcRZIiIisj5W3ZNCRERE8sWQQkRERLLEkEJERESyxJBCREREsmTVIeXdd99F79694eDggHHjxuHEiROWblKHrFu3DmPGjIGrqyu8vb0xd+5cpKammu0zZcoUKBQKs6+nn37abJ+srCzMmjULTk5O8Pb2xsqVK1FfX2+2T0xMDEaNGgWVSoV+/fph8+bNN/rt/WqrV6++7j0OGjRIery2thZLly6Fp6cnXFxccP/996OgoMDsObp7DXr37n1dDRQKBZYuXQrg1vwcHD58GLNnz4a/vz8UCgV27Nhh9rgQAi+99BL8/Pzg6OiIyMhIXLp0yWyf0tJSLFiwAGq1Gu7u7li0aBEqKyvN9jlz5gwmTZoEBwcHBAUF4fXXX7+uLdu2bcOgQYPg4OCA0NBQ7Nq1q8vfb2vaqoNer8cLL7yA0NBQODs7w9/fHwsXLkRubq7Zc7T0+Vm/fr3ZPnKuwy99Fh577LHr3t+MGTPM9unun4VfqkFLPx8UCgXeeOMNaR9ZfQ6Eldq6dauwt7cXn3zyiTh79qxYsmSJcHd3FwUFBZZuWrtNnz5dfPrppyIlJUUkJiaKu+++WwQHB4vKykppn8mTJ4slS5aIvLw86Uuj0UiP19fXi2HDhonIyEhx+vRpsWvXLuHl5SVWrVol7XP58mXh5OQkVqxYIc6dOyf+/e9/CxsbG7Fnz56b+n5b8/LLL4uhQ4eavceioiLp8aeffloEBQWJ6OhoER8fL8aPHy8mTJggPX4r1KCwsNDs/UdFRQkA4uDBg0KIW/NzsGvXLvHXv/5VfP/99wKA2L59u9nj69evF25ubmLHjh0iKSlJ3HPPPSIkJETU1NRI+8yYMUMMHz5cHDt2TPz888+iX79+Yv78+dLjGo1G+Pj4iAULFoiUlBTx9ddfC0dHR/H+++9L+xw9elTY2NiI119/XZw7d068+OKLws7OTiQnJ9/wGgjRdh3Ky8tFZGSk+Oabb8SFCxdEXFycGDt2rAgPDzd7jl69eolXXnnF7PPR/OeI3OvwS5+FRx99VMyYMcPs/ZWWlprt090/C79Ug+bvPS8vT3zyySdCoVCI9PR0aR85fQ6sNqSMHTtWLF26VLrf0NAg/P39xbp16yzY
qq5RWFgoAIhDhw5J2yZPniyWL1/e6jG7du0SSqVS5OfnS9s2bdok1Gq10Ol0Qgghnn/+eTF06FCz4x588EExffr0rn0DHfTyyy+L4cOHt/hYeXm5sLOzE9u2bZO2nT9/XgAQcXFxQohbowbXWr58uejbt68wGAxCiFv/c3DtD2WDwSB8fX3FG2+8IW0rLy8XKpVKfP3110IIIc6dOycAiJMnT0r77N69WygUCpGTkyOEEOI///mP6NGjh1QDIYR44YUXxMCBA6X7v/3tb8WsWbPM2jNu3Djx1FNPdel7/DVa+uV0rRMnTggAIjMzU9rWq1cvsXHjxlaP6U51aC2kzJkzp9VjbrXPwq/5HMyZM0fceeedZtvk9DmwyuGeuro6JCQkIDIyUtqmVCoRGRmJuLg4C7asa2g0GgCAh4eH2favvvoKXl5eGDZsGFatWoXq6mrpsbi4OISGhsLHx0faNn36dGi1Wpw9e1bap3nNTPvIqWaXLl2Cv78/+vTpgwULFiArKwsAkJCQAL1eb9b+QYMGITg4WGr/rVIDk7q6Onz55Zd44oknzC7MaQ2fA5OMjAzk5+ebtdfNzQ3jxo0z+393d3fH6NGjpX0iIyOhVCpx/PhxaZ/bb78d9vb20j7Tp09HamoqysrKpH26S10A488JhUIBd3d3s+3r16+Hp6cnRo4ciTfeeMNsqO9WqENMTAy8vb0xcOBAPPPMMygpKZEes7bPQkFBAX766ScsWrTousfk8jnoFhcY7GrFxcVoaGgw+0EMAD4+Prhw4YKFWtU1DAYDnnvuOUycOBHDhg2Ttj/88MPo1asX/P39cebMGbzwwgtITU3F999/DwDIz89vsR6mx9raR6vVoqamBo6Ojjfyrf2icePGYfPmzRg4cCDy8vKwZs0aTJo0CSkpKcjPz4e9vf11P5B9fHx+8f2ZHmtrH7nUoLkdO3agvLwcjz32mLTNGj4HzZna3FJ7m78fb29vs8dtbW3h4eFhtk9ISMh1z2F6rEePHq3WxfQcclJbW4sXXngB8+fPN7to3LPPPotRo0bBw8MDsbGxWLVqFfLy8vDWW28B6P51mDFjBu677z6EhIQgPT0df/nLXzBz5kzExcXBxsbG6j4Ln332GVxdXXHfffeZbZfT58AqQ8qtbOnSpUhJScGRI0fMtj/55JPSv0NDQ+Hn54e77roL6enp6Nu3781u5g0xc+ZM6d9hYWEYN24cevXqhW+//VZWvzhvlo8//hgzZ86Ev7+/tM0aPgfUNr1ej9/+9rcQQmDTpk1mj61YsUL6d1hYGOzt7fHUU09h3bp1t8Ty8A899JD079DQUISFhaFv376IiYnBXXfdZcGWWcYnn3yCBQsWwMHBwWy7nD4HVjnc4+XlBRsbm+vO7CgoKICvr6+FWtV5y5Ytw48//oiDBw8iMDCwzX3HjRsHAEhLSwMA+Pr6tlgP02Nt7aNWq2UZAtzd3TFgwACkpaXB19cXdXV1KC8vN9un+f/5rVSDzMxM7N+/H4sXL25zv1v9c2Bqc1vf676+vigsLDR7vL6+HqWlpV3y2ZDTzxRTQMnMzERUVJRZL0pLxo0bh/r6ely5cgXArVMHkz59+sDLy8vs828tn4Wff/4Zqampv/gzArDs58AqQ4q9vT3Cw8MRHR0tbTMYDIiOjkZERIQFW9YxQggsW7YM27dvx4EDB67rhmtJYmIiAMDPzw8AEBERgeTkZLNvUNMPsSFDhkj7NK+ZaR+51qyyshLp6enw8/NDeHg47OzszNqfmpqKrKwsqf23Ug0+/fRTeHt7Y9asWW3ud6t/DkJCQuDr62vWXq1Wi+PHj5v9v5eXlyMhIUHa58CBAzAYDFKIi4iIwOHDh6HX66V9oqKiMHDgQPTo0UPaR851MQWUS5cuYf/+/fD09PzFYxITE6FUKqUhkFuhDs1dvXoVJSUlZp9/a/gsAMae1vDwcAwfPvwX97Xo56Bd02xvIVu3bhUqlUps3rxZnDt3
Tjz55JPC3d3d7KyG7uKZZ54Rbm5uIiYmxuyUserqaiGEEGlpaeKVV14R8fHxIiMjQ+zcuVP06dNH3H777dJzmE49nTZtmkhMTBR79uwRPXv2bPHU05UrV4rz58+Ld999V1an3/7pT38SMTExIiMjQxw9elRERkYKLy8vUVhYKIQwnoIcHBwsDhw4IOLj40VERISIiIiQjr8VaiCE8Uy14OBg8cILL5htv1U/BxUVFeL06dPi9OnTAoB46623xOnTp6WzVtavXy/c3d3Fzp07xZkzZ8ScOXNaPAV55MiR4vjx4+LIkSOif//+ZqedlpeXCx8fH/HII4+IlJQUsXXrVuHk5HTdKZe2trZiw4YN4vz58+Lll1++qacgt1WHuro6cc8994jAwECRmJho9nPCdIZGbGys2Lhxo0hMTBTp6eniyy+/FD179hQLFy7sNnVoqwYVFRXiz3/+s4iLixMZGRli//79YtSoUaJ///6itrZWeo7u/ln4pe8HIYynEDs5OYlNmzZdd7zcPgdWG1KEEOLf//63CA4OFvb29mLs2LHi2LFjlm5ShwBo8evTTz8VQgiRlZUlbr/9duHh4SFUKpXo16+fWLlypdn6GEIIceXKFTFz5kzh6OgovLy8xJ/+9Ceh1+vN9jl48KAYMWKEsLe3F3369JFeQw4efPBB4efnJ+zt7UVAQIB48MEHRVpamvR4TU2N+P3vfy969OghnJycxL333ivy8vLMnqO710AIIfbu3SsAiNTUVLPtt+rn4ODBgy1+/h999FEhhPE05L/97W/Cx8dHqFQqcdddd11Xm5KSEjF//nzh4uIi1Gq1ePzxx0VFRYXZPklJSeK2224TKpVKBAQEiPXr11/Xlm+//VYMGDBA2Nvbi6FDh4qffvrphr3va7VVh4yMjFZ/TpjW0ElISBDjxo0Tbm5uwsHBQQwePFisXbvW7Be4EPKuQ1s1qK6uFtOmTRM9e/YUdnZ2olevXmLJkiXX/WHa3T8Lv/T9IIQQ77//vnB0dBTl5eXXHS+3z4FCCCHa1/dCREREdONZ5ZwUIiIikj+GFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSJYYUIiIikiWGFCIiIpIlhhQiIiKSpf8HkaNzWeYAcsQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot every 50th logged training entry to keep the curve readable.\n",
    "train_points = record[\"train\"][::50]\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([p[\"step\"] for p in train_points], [p[\"loss\"] for p in train_points], label=\"train\")\n",
    "ax.set(xlabel=\"step\", ylabel=\"loss\", title=\"Training loss\")\n",
    "ax.legend()  # the `label` above is only shown once a legend is drawn\n",
    "ax.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T07:36:10.207406Z",
     "iopub.status.busy": "2025-01-25T07:36:10.207176Z",
     "iopub.status.idle": "2025-01-25T07:36:11.129623Z",
     "shell.execute_reply": "2025-01-25T07:36:11.129060Z",
     "shell.execute_reply.started": "2025-01-25T07:36:10.207389Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_674/3607945331.py:26: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(\"checkpoints/text_generation_lstm/best.ckpt\", map_location=\"cpu\"))\n",
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All: a"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/1000 [00:00<02:04,  8.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nd I come by him.\n",
      "\n",
      "PETRUCHIO:\n",
      "Come on, i' God's name, speak; for what life's I have worn\n",
      "That you may conquer or soun"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 118/1000 [00:00<00:01, 628.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d this one:\n",
      "Come, what letters of Hereford till God's name,\n",
      "In most I trust my want of water:\n",
      "And, for the senators, the mighty prince,\n",
      "L"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 255/1000 [00:00<00:00, 947.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "end two the bridgling throne and the fair staring\n",
      "Than at your walls, as mine hath cheepts to do,\n",
      "Without a subject spent, it makes hi"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 39%|███▉      | 389/1000 [00:00<00:00, 1095.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m to it,\n",
      "The more than his or soul and tears are gone.\n",
      "\n",
      "CLIFFORD:\n",
      "The king is nothing: I say the tractless in\n",
      "the service of your dau"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 522/1000 [00:00<00:00, 1177.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ghter to his will and that with her\n",
      "sworn.\n",
      "\n",
      "PROSPERO:\n",
      "Dost thou, that e'er thou wast but ladies' brawls,\n",
      "And what they say that she did los"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 661/1000 [00:00<00:00, 1248.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "e his life?\n",
      "\n",
      "GLOUCESTER:\n",
      "You may deny that for my counsels,\n",
      "Even thus the very time Aufidius caps it before\n",
      "The second cur is but a drunken"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 800/1000 [00:00<00:00, 1291.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " made\n",
      "As thou canst gradle carried and sullen venom,\n",
      "And on athanting them from France speaks.\n",
      "\n",
      "Nurse:\n",
      "Now, afore God! methinks I h"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 93%|█████████▎| 931/1000 [00:00<00:00, 1258.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ear gently for thee.\n",
      "\n",
      "AUTOLYCUS:\n",
      "No, good shepherd, will you not hure"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:00<00:00, 1106.37it/s]\n"
     ]
    }
   ],
   "source": [
    "def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True):\n",
    "    \"\"\"Autoregressively sample `max_len` characters from the character-level `model`.\n",
    "\n",
    "    Args:\n",
    "        model: called as `model(input_ids, hidden=hidden)` and expected to return\n",
    "            `(logits, hidden)`; logits are indexed as `logits[0, -1, :]`, i.e. assumed\n",
    "            (batch, seq_len, vocab) -- TODO confirm against the model definition.\n",
    "        start_string: non-empty prompt; every character must be a key of `char2idx`.\n",
    "        max_len: number of characters to generate.\n",
    "        temperature: softmax temperature, must be >= 0. Values < 1 sharpen the\n",
    "            distribution, values > 1 flatten it; 0 means greedy (argmax) decoding.\n",
    "        stream: if True, print the prompt and each sampled character as it is produced\n",
    "            (the tqdm progress bar is disabled so it does not interleave with the text).\n",
    "\n",
    "    Returns:\n",
    "        The generated text, excluding `start_string`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: on an empty prompt, out-of-vocabulary prompt characters, or a\n",
    "            negative temperature.\n",
    "    \"\"\"\n",
    "    if not start_string:\n",
    "        raise ValueError(\"start_string must be non-empty\")\n",
    "    unknown = sorted({char for char in start_string if char not in char2idx})\n",
    "    if unknown:\n",
    "        raise ValueError(f\"start_string contains characters outside the vocabulary: {unknown}\")\n",
    "    if temperature < 0:\n",
    "        raise ValueError(f\"temperature must be >= 0, got {temperature}\")\n",
    "\n",
    "    # The full prompt is fed once; afterwards only the last sampled char is fed,\n",
    "    # with the context carried in the recurrent `hidden` state.\n",
    "    input_eval = torch.tensor([[char2idx[char] for char in start_string]], dtype=torch.long, device=device)\n",
    "    hidden = None\n",
    "    text_generated = []\n",
    "    model.eval()\n",
    "    if stream:\n",
    "        print(start_string, end=\"\")\n",
    "    with torch.no_grad():\n",
    "        for _ in tqdm(range(max_len), disable=stream):\n",
    "            logits, hidden = model(input_eval, hidden=hidden)\n",
    "            logits = logits[0, -1, :]\n",
    "            if temperature == 0:\n",
    "                # greedy decoding: dividing by zero would produce inf/nan logits\n",
    "                idx = int(torch.argmax(logits).item())\n",
    "            else:\n",
    "                # temperature sampling: rescale the logits, then draw from the softmax distribution\n",
    "                probs = F.softmax(logits / temperature, dim=-1)\n",
    "                idx = torch.multinomial(probs, 1).item()\n",
    "            input_eval = torch.tensor([[idx]], dtype=torch.long, device=device)\n",
    "            text_generated.append(idx)\n",
    "            if stream:\n",
    "                print(idx2char[idx], end=\"\", flush=True)\n",
    "    if stream:\n",
    "        print()  # terminate the streamed line\n",
    "    return \"\".join(idx2char[i] for i in text_generated)\n",
    "\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "# Load the checkpoint. It is passed straight to load_state_dict, so weights_only=True\n",
    "# suffices and avoids unpickling arbitrary objects.\n",
    "model.load_state_dict(torch.load(\"checkpoints/text_generation_lstm/best.ckpt\", map_location=\"cpu\", weights_only=True))\n",
    "start_string = \"All: \"\n",
    "res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
